# -*- coding: utf-8 -*-
import datetime
import logging

from pymongo import MongoClient
from zc_core.model.items import Sku


class CatalogCompletePipeline(object):
    """Item pipeline that fills in missing parent-category fields on Sku items.

    When the spider opens, the ``catalog_pool`` collection is read from the
    ``<bot_name>_<current year>`` MongoDB database and flattened into three
    lookup tables (one per category level). ``process_item`` then uses those
    tables to complete the level-1/level-2 category ids and names of an item
    from its level-3 (or level-2) category id.
    """

    def __init__(self, mongo_uri, bot_name):
        """Store connection settings; no connection is opened here.

        :param mongo_uri: MongoDB connection URI.
        :param bot_name: bot name, used as the database-name prefix.
        """
        self.mongo_uri = mongo_uri
        self.bot_name = bot_name
        self.client = None
        self.db = None
        # Lookup tables keyed by the category ``_id``. Created here so that
        # ``process_item`` can be called safely even if they were never built.
        self.cat1_map = dict()
        self.cat2_map = dict()
        self.cat3_map = dict()

    @classmethod
    def from_crawler(cls, crawler):
        """Build the pipeline from the crawler's ``MONGODB_URI`` and ``BOT_NAME`` settings."""
        settings = crawler.settings
        return cls(
            mongo_uri=settings.get('MONGODB_URI'),
            bot_name=settings.get('BOT_NAME')
        )

    def open_spider(self, spider):
        """Connect to MongoDB and load the category lookup tables.

        :param spider: the spider being opened (unused).
        """
        _ = spider
        self.client = MongoClient(self.mongo_uri)
        # Use the database of the current year by default.
        year = str(datetime.datetime.now().year)
        self.db = self.client['{}_{}'.format(self.bot_name, year)]

        # ``get_collection`` always returns a Collection object, and PyMongo
        # collections do not support truth-value testing, so no ``if`` guard
        # is used here. An empty collection simply yields empty tables.
        pool = self.db.get_collection('catalog_pool')
        # Sorting by level guarantees parents are seen before their children.
        self._build_plain_cat(pool.find().sort("level"))

    @staticmethod
    def _find_parent(parent_map, parent_id):
        """Return the parent entry for ``parent_id`` or ``None`` if unknown.

        The stringified id is tried first (the original lookup form), then the
        raw value, so that non-string ``_id`` keys can still be matched.
        """
        parent = parent_map.get(str(parent_id))
        if parent is None:
            parent = parent_map.get(parent_id)
        return parent

    def _build_plain_cat(self, pool):
        """Flatten category rows into the per-level lookup tables.

        :param pool: iterable of category rows (mappings with ``_id``,
            ``level``, ``catalogName`` and, for levels 2 and 3, ``parentId``),
            ordered so that lower levels come first.

        Rows whose parent cannot be found are skipped with a warning instead
        of aborting the whole build.
        """
        self.cat1_map = dict()
        self.cat2_map = dict()
        self.cat3_map = dict()
        logger = logging.getLogger(__name__)
        for row in pool:
            level = row.get('level')
            row_id = row.get('_id')
            name = row.get('catalogName')
            if level == 1:
                self.cat1_map[row_id] = {
                    'cat1Id': row_id,
                    'cat1Name': name,
                }
            elif level == 2:
                cat1 = self._find_parent(self.cat1_map, row.get('parentId'))
                if cat1 is None:
                    logger.warning('catalog %s skipped: level-1 parent %s not found',
                                   row_id, row.get('parentId'))
                    continue
                self.cat2_map[row_id] = {
                    'cat1Id': cat1.get('cat1Id'),
                    'cat1Name': cat1.get('cat1Name'),
                    'cat2Id': row_id,
                    'cat2Name': name,
                }
            elif level == 3:
                cat2 = self._find_parent(self.cat2_map, row.get('parentId'))
                if cat2 is None:
                    logger.warning('catalog %s skipped: level-2 parent %s not found',
                                   row_id, row.get('parentId'))
                    continue
                self.cat3_map[row_id] = {
                    'cat1Id': cat2.get('cat1Id'),
                    'cat1Name': cat2.get('cat1Name'),
                    'cat2Id': cat2.get('cat2Id'),
                    'cat2Name': cat2.get('cat2Name'),
                    'cat3Id': row_id,
                    'cat3Name': name,
                }

    def process_item(self, item, spider):
        """Complete missing category fields on ``Sku`` items.

        :param item: the scraped item; only ``Sku`` instances are modified.
        :param spider: the spider that produced the item (unused).
        :return: the (possibly updated) item.
        """
        # Product data
        if isinstance(item, Sku):
            if item.get('catalog3Id') and self.cat3_map and \
                    (not item.get('catalog1Id') or not item.get('catalog1Name')
                     or not item.get('catalog2Id') or not item.get('catalog2Name')):
                cat3 = self.cat3_map.get(item.get('catalog3Id'))
                if cat3:
                    # Level-1 category id
                    item['catalog1Id'] = cat3.get('cat1Id')
                    # Level-1 category name
                    item['catalog1Name'] = cat3.get('cat1Name')
                    # Level-2 category id
                    item['catalog2Id'] = cat3.get('cat2Id')
                    # Level-2 category name
                    item['catalog2Name'] = cat3.get('cat2Name')
                    # Level-3 category name
                    item['catalog3Name'] = cat3.get('cat3Name')
            elif item.get('catalog2Id') and self.cat2_map and \
                    (not item.get('catalog1Id') or not item.get('catalog1Name')):
                cat2 = self.cat2_map.get(item.get('catalog2Id'))
                if cat2:
                    # Level-1 category id
                    item['catalog1Id'] = cat2.get('cat1Id')
                    # Level-1 category name
                    item['catalog1Name'] = cat2.get('cat1Name')
                    # Level-2 category name
                    item['catalog2Name'] = cat2.get('cat2Name')

        return item

    def close_spider(self, spider):
        """Close the MongoDB connection if one was opened.

        :param spider: the spider being closed (unused).
        """
        _ = spider
        # ``client`` is still None when ``open_spider`` never ran or failed
        # before connecting; closing must not raise in that case.
        if self.client is not None:
            self.client.close()
            self.client = None
            self.db = None
